import csv
import os
import pandas as pd
import numpy as np
from scipy import io
import torch.utils.data as data
from PIL import Image


class KONIQDATASET(data.Dataset):
    """Dataset reading ``koniq10k_scores_and_distributions.csv`` under ``root``.

    For every position in ``index`` (a row number of the CSV), ``patch_num``
    copies of ``(root/1024x768/<image_name>, MOS_zscore)`` are added to
    ``self.samples``.

    Args:
        root: dataset root directory.
        index: iterable of CSV row positions to include.
        patch_num: number of times each selected image is repeated.
        transform: optional callable applied to the loaded PIL image.
    """

    def __init__(self, root, index, patch_num, transform=None):
        super(KONIQDATASET, self).__init__()

        self.data_path = root
        imgname = []
        mos_all = []
        csv_file = os.path.join(root, "koniq10k_scores_and_distributions.csv")
        # newline="" is what the csv module expects for files it parses.
        with open(csv_file, newline="", encoding="utf-8") as f:
            for row in csv.DictReader(f):
                imgname.append(row["image_name"])
                # 0-d float32 array, identical to the historic representation.
                mos_all.append(np.array(float(row["MOS_zscore"]), dtype=np.float32))

        self.samples = [
            (os.path.join(root, "1024x768", imgname[item]), mos_all[item])
            for item in index
            for _ in range(patch_num)
        ]
        self.transform = transform

    def _load_image(self, path):
        """Load ``path`` as an RGB PIL image.

        If loading fails, a 224x224 random-noise image is returned instead of
        raising.
        """
        try:
            im = Image.open(path).convert("RGB")
        except Exception:  # not bare: let KeyboardInterrupt/SystemExit propagate
            print("ERROR IMG LOADED: ", path)
            random_img = np.random.rand(224, 224, 3) * 255
            im = Image.fromarray(np.uint8(random_img))
        return im

    def __getitem__(self, index):
        """Return the ``index``-th ``(sample, target)`` pair.

        Args:
            index (int): Position in ``self.samples``.

        Returns:
            tuple: ``(sample, target)`` where ``sample`` is the loaded image
            (passed through ``transform`` when one was given) and ``target``
            is its float32 quality score.
        """
        path, target = self.samples[index]
        sample = self._load_image(path)
        # transform defaults to None, so it must be optional here.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, target

    def __len__(self):
        return len(self.samples)


class LIVECDATASET(data.Dataset):
    """Dataset backed by ``Data/AllImages_release.mat`` and ``Data/AllMOS_release.mat``.

    Only entries 7..1168 of both arrays are used (presumably the LIVE Challenge
    test images, with the first 7 being its training images — confirm).
    ``index`` selects positions inside that slice, and each selected image is
    repeated ``patch_num`` times.

    Args:
        root: dataset root containing ``Data/`` and ``Images/``.
        index: iterable of positions into the sliced arrays.
        patch_num: number of copies of each selected image.
        transform: optional callable applied to the loaded PIL image.
    """

    def __init__(self, root, index, patch_num, transform=None):
        super(LIVECDATASET, self).__init__()

        mat_dir = os.path.join(root, "Data")
        imgpath = io.loadmat(os.path.join(mat_dir, "AllImages_release.mat"))
        imgpath = imgpath["AllImages_release"][7:1169]
        mos = io.loadmat(os.path.join(mat_dir, "AllMOS_release.mat"))
        labels = mos["AllMOS_release"].astype(np.float32)[0][7:1169]

        # imgpath[item][0][0] unwraps the MATLAB cell entry to its file name.
        self.samples = [
            (os.path.join(root, "Images", imgpath[item][0][0]), labels[item])
            for item in index
            for _ in range(patch_num)
        ]
        self.transform = transform

    def _load_image(self, path):
        """Load ``path`` as an RGB PIL image.

        If loading fails, a 224x224 random-noise image is returned instead of
        raising.
        """
        try:
            im = Image.open(path).convert("RGB")
        except Exception:  # not bare: let KeyboardInterrupt/SystemExit propagate
            print("ERROR IMG LOADED: ", path)
            random_img = np.random.rand(224, 224, 3) * 255
            im = Image.fromarray(np.uint8(random_img))
        return im

    def __getitem__(self, index):
        """Return the ``index``-th ``(sample, target)`` pair.

        Args:
            index (int): Position in ``self.samples``.

        Returns:
            tuple: ``(sample, target)`` where ``sample`` is the loaded image
            (passed through ``transform`` when one was given) and ``target``
            is its float32 quality score.
        """
        path, target = self.samples[index]
        sample = self._load_image(path)
        # transform defaults to None, so it must be optional here.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, target

    def __len__(self):
        return len(self.samples)


class UWIQADATASET(data.Dataset):
    """Dataset backed by ``Data/AllImages_release.mat`` and ``Data/AllMOS_release.mat``.

    Only the first 890 entries of both arrays are used. ``index`` selects
    positions inside that slice, and each selected image is repeated
    ``patch_num`` times.

    Args:
        root: dataset root containing ``Data/`` and ``Images/``.
        index: iterable of positions into the sliced arrays.
        patch_num: number of copies of each selected image.
        transform: optional callable applied to the loaded PIL image.
    """

    def __init__(self, root, index, patch_num, transform=None):
        super(UWIQADATASET, self).__init__()

        mat_dir = os.path.join(root, "Data")
        imgpath = io.loadmat(os.path.join(mat_dir, "AllImages_release.mat"))
        imgpath = imgpath["AllImages_release"][0:890]
        mos = io.loadmat(os.path.join(mat_dir, "AllMOS_release.mat"))
        labels = mos["AllMOS_release"].astype(np.float32)[0][0:890]

        # imgpath[item][0][0] unwraps the MATLAB cell entry to its file name.
        self.samples = [
            (os.path.join(root, "Images", imgpath[item][0][0]), labels[item])
            for item in index
            for _ in range(patch_num)
        ]
        self.transform = transform

    def _load_image(self, path):
        """Load ``path`` as an RGB PIL image.

        If loading fails, a 224x224 random-noise image is returned instead of
        raising.
        """
        try:
            im = Image.open(path).convert("RGB")
        except Exception:  # not bare: let KeyboardInterrupt/SystemExit propagate
            print("ERROR IMG LOADED: ", path)
            random_img = np.random.rand(224, 224, 3) * 255
            im = Image.fromarray(np.uint8(random_img))
        return im

    def __getitem__(self, index):
        """Return the ``index``-th ``(sample, target)`` pair.

        Args:
            index (int): Position in ``self.samples``.

        Returns:
            tuple: ``(sample, target)`` where ``sample`` is the loaded image
            (passed through ``transform`` when one was given) and ``target``
            is its float32 quality score.
        """
        path, target = self.samples[index]
        sample = self._load_image(path)
        # transform defaults to None, so it must be optional here.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, target

    def __len__(self):
        return len(self.samples)


class LIVEDataset(data.Dataset):
    """Dataset split by reference image, built from ``root``'s distortion folders.

    ``index`` selects reference images, given as positions in the sorted list
    of ``root/refimgs/*.bmp``. Every entry whose ``refnames_all`` name equals
    a selected reference and whose ``orgs`` flag is zero is added
    ``patch_num`` times, paired with its ``dmos_new`` value.

    Args:
        root: dataset root directory.
        index: iterable of positions into the sorted reference-name list.
        patch_num: number of copies of each selected image.
        transform: optional callable applied to the loaded PIL image.
    """

    def __init__(self, root, index, patch_num, transform=None):
        refpath = os.path.join(root, "refimgs")
        refname = getFileName(refpath, ".bmp")

        # Concatenated in this fixed order; imgpath[k] is paired below with
        # column k of dmos_new / refnames_all / orgs.
        imgpath = []
        for folder, count in (
            ("jp2k", 227),
            ("jpeg", 233),
            ("wn", 174),
            ("gblur", 174),
            ("fastfading", 174),
        ):
            imgpath += self.getDistortionTypeFileName(os.path.join(root, folder), count)

        dmos = io.loadmat(os.path.join(root, "dmos_realigned.mat"))
        labels = dmos["dmos_new"].astype(np.float32)

        orgs = dmos["orgs"]
        refnames_all = io.loadmat(os.path.join(root, "refnames_all.mat"))
        refnames_all = refnames_all["refnames_all"]

        refname.sort()
        # Loop-invariant: True where orgs == 0 (presumably the distorted,
        # non-original entries — confirm against the .mat documentation).
        not_org = ~orgs.astype(np.bool_)

        sample = []
        for ref_idx in index:
            train_sel = (refname[ref_idx] == refnames_all) * not_org
            # Column indices of the matching entries (arrays indexed as 2-D).
            train_sel = np.where(train_sel == True)[1].tolist()  # noqa: E712
            for item in train_sel:
                for _ in range(patch_num):
                    sample.append((imgpath[item], labels[0][item]))
        self.samples = sample
        self.transform = transform

    def _load_image(self, path):
        """Load ``path`` as an RGB PIL image.

        If loading fails, a 224x224 random-noise image is returned instead of
        raising.
        """
        try:
            im = Image.open(path).convert("RGB")
        except Exception:  # not bare: let KeyboardInterrupt/SystemExit propagate
            print("ERROR IMG LOADED: ", path)
            random_img = np.random.rand(224, 224, 3) * 255
            im = Image.fromarray(np.uint8(random_img))
        return im

    def __getitem__(self, index):
        """Return the ``index``-th ``(sample, target)`` pair.

        Args:
            index (int): Position in ``self.samples``.

        Returns:
            tuple: ``(sample, target)`` where ``sample`` is the loaded image
            (passed through ``transform`` when one was given) and ``target``
            is its float32 DMOS value.
        """
        path, target = self.samples[index]
        sample = self._load_image(path)
        if self.transform is not None:
            sample = self.transform(sample)

        return sample, target

    def __len__(self):
        return len(self.samples)

    def getDistortionTypeFileName(self, path, num):
        """Return ``[path/img1.bmp, path/img2.bmp, ..., path/img<num>.bmp]``."""
        return [os.path.join(path, "img%d.bmp" % k) for k in range(1, num + 1)]


def getFileName(path, suffix):
    """List the entry names in ``path`` whose extension is exactly ``suffix``.

    The comparison is case-sensitive. Bare names (not full paths) are returned
    in ``os.listdir`` order.
    """
    return [entry for entry in os.listdir(path) if os.path.splitext(entry)[1] == suffix]


class TID2013Dataset(data.Dataset):
    """Dataset split by reference image, read from ``mos_with_names.txt``.

    ``index`` selects reference images, given as positions in the sorted list
    of two-character reference ids from ``root/reference_images``. Every
    distorted image whose id matches is added ``patch_num`` times.

    Args:
        root: dataset root directory.
        index: iterable of positions into the sorted reference-id list.
        patch_num: number of copies of each selected image.
        transform: optional callable applied to the loaded PIL image.
    """

    def __init__(self, root, index, patch_num, transform=None):
        refpath = os.path.join(root, "reference_images")
        refname = getTIDFileName(refpath, ".bmp.BMP")
        txtpath = os.path.join(root, "mos_with_names.txt")

        imgnames = []
        target = []
        refnames_all = []
        # Each line is "<score> <image name>". The reference id is the part of
        # the name before the first "_", minus its first character.
        with open(txtpath, "r") as fh:
            for line in fh:
                words = line.split()
                if not words:
                    continue  # tolerate blank lines instead of IndexError
                imgnames.append(words[1])
                target.append(words[0])
                refnames_all.append(words[1].split("_")[0][1:])
        labels = np.array(target).astype(np.float32)
        refnames_all = np.array(refnames_all)

        refname.sort()
        sample = []
        for ref_idx in index:
            # Positions of all distorted images derived from this reference.
            train_sel = np.where(refnames_all == refname[ref_idx])[0].tolist()
            for item in train_sel:
                for _ in range(patch_num):
                    sample.append(
                        (
                            os.path.join(root, "distorted_images", imgnames[item]),
                            labels[item],
                        )
                    )
        self.samples = sample
        self.transform = transform

    def _load_image(self, path):
        """Load ``path`` as an RGB PIL image.

        If loading fails, a 224x224 random-noise image is returned instead of
        raising.
        """
        try:
            im = Image.open(path).convert("RGB")
        except Exception:  # not bare: let KeyboardInterrupt/SystemExit propagate
            print("ERROR IMG LOADED: ", path)
            random_img = np.random.rand(224, 224, 3) * 255
            im = Image.fromarray(np.uint8(random_img))
        return im

    def __getitem__(self, index):
        """Return the ``index``-th ``(sample, target)`` pair.

        Args:
            index (int): Position in ``self.samples``.

        Returns:
            tuple: ``(sample, target)`` where ``sample`` is the loaded image
            (passed through ``transform`` when one was given) and ``target``
            is its float32 quality score.
        """
        path, target = self.samples[index]
        sample = self._load_image(path)
        # transform defaults to None, so it must be optional here.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, target

    def __len__(self):
        return len(self.samples)


def getTIDFileName(path, suffix):
    """Return the two-character reference ids of the image files in ``path``.

    ``suffix`` is the accepted extensions written back to back, e.g.
    ``".bmp.BMP"`` accepts ``.bmp`` and ``.BMP``. For each matching file,
    characters 1-2 of its name are returned (``"I01.BMP"`` -> ``"01"``), in
    ``os.listdir`` order.
    """
    # ".bmp.BMP" -> {".bmp", ".BMP"}. Exact set membership means entries with
    # no extension (splitext gives "", e.g. ".DS_Store" or sub-directories) and
    # partial extensions such as ".b" are no longer picked up, as they were by
    # a substring search.
    extensions = {"." + ext for ext in suffix.split(".") if ext}
    return [
        name[1:3]
        for name in os.listdir(path)
        if os.path.splitext(name)[1] in extensions
    ]


class CSIQDataset(data.Dataset):
    """Dataset split by reference image, read from ``csiq_label.txt``.

    ``index`` selects reference images, given as positions in the sorted list
    of ``root/src_imgs/*.png``. Every labelled distorted image whose name
    contains that reference's base name, and that exists under
    ``root/dst_imgs_all``, is added ``patch_num`` times.

    Args:
        root: dataset root directory.
        index: iterable of positions into the sorted reference-name list.
        patch_num: number of copies of each selected image.
        transform: optional callable applied to the loaded PIL image.
    """

    def __init__(self, root, index, patch_num, transform=None):
        refpath = os.path.join(root, "src_imgs")
        # Reference image file names. They are sorted so that a given index
        # means the same reference everywhere (os.listdir order is arbitrary).
        refname = sorted(getFileName(refpath, ".png"))
        txtpath = os.path.join(root, "csiq_label.txt")

        imgnames = []
        target = []
        refnames_all = []
        with open(txtpath, "r") as fh:
            for line in fh:
                words = line.split()
                if not words:
                    continue  # tolerate blank lines instead of IndexError
                imgnames.append(words[0])
                target.append(words[1])
                ref_temp = words[0].split(".")
                # Reference part of the distorted name: first + last dot-separated pieces.
                refnames_all.append(ref_temp[0] + "." + ref_temp[-1])

        labels = np.array(target).astype(np.float32)

        # Debug output: the first five reference / distorted names.
        print(f"Reference images: {refname[:5]}")
        print(f"Distorted images: {imgnames[:5]}")

        dist_dir = os.path.join(root, "dst_imgs_all")
        sample = []
        for item in index:
            if item >= len(refname):
                print(f"Warning: Index {item} out of bounds for reference images (size {len(refname)})")
                continue

            # Partial (substring) match against the reference base name.
            # NOTE(review): this over-matches if one reference name is a
            # substring of another — confirm this cannot happen for CSIQ.
            ref_base_name = os.path.splitext(refname[item])[0]
            train_sel = [j for j, distorted in enumerate(refnames_all) if ref_base_name in distorted]

            if not train_sel:
                print(f"Warning: No matching distorted images found for reference image {refname[item]}")
                continue

            for sel_item in train_sel:
                # Make sure the image path carries the ".png" extension.
                image_path = os.path.join(dist_dir, imgnames[sel_item])
                if not image_path.endswith(".png"):
                    image_path += ".png"

                # Checked once per image rather than once per patch.
                if not os.path.exists(image_path):
                    print(f"Warning: Image not found: {image_path}")
                    continue

                for _ in range(patch_num):
                    sample.append((image_path, labels[sel_item]))

        self.samples = sample
        self.transform = transform

        print(f"Created dataset with {len(self.samples)} samples")

    def _load_image(self, path):
        """Load ``path`` as an RGB PIL image.

        If loading fails, a 224x224 random-noise image is returned instead of
        raising.
        """
        try:
            im = Image.open(path).convert("RGB")
        except Exception as e:
            print(f"Error loading image {path}: {e}")
            random_img = np.random.rand(224, 224, 3) * 255
            im = Image.fromarray(np.uint8(random_img))
        return im

    def __getitem__(self, index):
        """Return the ``index``-th ``(sample, target)`` pair.

        Args:
            index (int): Position in ``self.samples``.

        Returns:
            tuple: ``(sample, target)`` where ``sample`` is the loaded image
            (passed through ``transform`` when one was given) and ``target``
            is its float32 quality score.
        """
        path, target = self.samples[index]
        sample = self._load_image(path)
        # transform defaults to None, so it must be optional here.
        if self.transform is not None:
            sample = self.transform(sample)

        return sample, target

    def __len__(self):
        return len(self.samples)

# class CSIQDataset(data.Dataset):
#     def __init__(self, root, index, patch_num, transform=None):
#
#         refpath = os.path.join(root, "src_imgs")
#         refname = getFileName(refpath, ".png")
#         txtpath = os.path.join(root, "csiq_label.txt")
#         fh = open(txtpath, "r")
#         imgnames = []
#         target = []
#         refnames_all = []
#         for line in fh:
#             line = line.split("\n")
#             words = line[0].split()
#             imgnames.append((words[0]))
#             target.append(words[1])
#             ref_temp = words[0].split(".")
#             refnames_all.append(ref_temp[0] + "." + ref_temp[-1])  #
#
#         labels = np.array(target).astype(np.float32)
#         refnames_all = np.array(refnames_all)
#
#         sample = []
#
#         for i, item in enumerate(index):
#             train_sel = refname[index[i]] == refnames_all
#             train_sel = np.where(train_sel == True)
#             train_sel = train_sel[0].tolist()
#             for j, item in enumerate(train_sel):
#                 for aug in range(patch_num):
#                     sample.append(
#                         (
#                             os.path.join(root, "dst_imgs_all", imgnames[item] + ".png"),
#                             labels[item],
#                         )
#                     )
#         self.samples = sample
#         self.transform = transform
#
#     def _load_image(self, path):
#         try:
#             im = Image.open(path).convert("RGB")
#         except:
#             print("ERROR IMG LOADED: ", path)
#             random_img = np.random.rand(224, 224, 3) * 255
#             im = Image.fromarray(np.uint8(random_img))
#         return im
#
#     def __getitem__(self, index):
#         """
#         Args:
#             index (int): Index
#         Returns:
#             tuple: (sample, target) where target is class_index of the target class.
#         """
#         path, target = self.samples[index]
#         sample = self._load_image(path)
#         sample = self.transform(sample)
#
#         return sample, target
#
#     def __len__(self):
#         length = len(self.samples)
#         return length


class KADIDDataset(data.Dataset):
    """Dataset split by reference image, read from ``dmos.csv``.

    ``index`` selects reference images, given as positions in the sorted list
    of two-character reference ids from ``root/reference_images``. Every
    distorted image whose ``ref_img`` id matches is added ``patch_num`` times.

    Args:
        root: dataset root directory.
        index: iterable of positions into the sorted reference-id list.
        patch_num: number of copies of each selected image.
        transform: optional callable applied to the loaded PIL image.
    """

    def __init__(self, root, index, patch_num, transform=None):
        refpath = os.path.join(root, "reference_images")
        refname = getTIDFileName(refpath, ".png.PNG")

        imgnames = []
        target = []
        refnames_all = []

        csv_file = os.path.join(root, "dmos.csv")
        # newline="" is what the csv module expects for files it parses.
        with open(csv_file, newline="", encoding="utf-8") as f:
            for row in csv.DictReader(f):
                imgnames.append(row["dist_img"])
                # Characters 1-2 of the reference name, matching getTIDFileName.
                refnames_all.append(row["ref_img"][1:3])
                target.append(float(row["dmos"]))

        labels = np.array(target).astype(np.float32)
        refnames_all = np.array(refnames_all)

        refname.sort()
        sample = []
        for ref_idx in index:
            # Positions of all distorted images derived from this reference.
            train_sel = np.where(refnames_all == refname[ref_idx])[0].tolist()
            for item in train_sel:
                for _ in range(patch_num):
                    sample.append(
                        (
                            os.path.join(root, "distorted_images", imgnames[item]),
                            labels[item],
                        )
                    )
        self.samples = sample
        self.transform = transform

    def _load_image(self, path):
        """Load ``path`` as an RGB PIL image.

        If loading fails, a 224x224 random-noise image is returned instead of
        raising.
        """
        try:
            im = Image.open(path).convert("RGB")
        except Exception:  # not bare: let KeyboardInterrupt/SystemExit propagate
            print("ERROR IMG LOADED: ", path)
            random_img = np.random.rand(224, 224, 3) * 255
            im = Image.fromarray(np.uint8(random_img))
        return im

    def __getitem__(self, index):
        """Return the ``index``-th ``(sample, target)`` pair.

        Args:
            index (int): Position in ``self.samples``.

        Returns:
            tuple: ``(sample, target)`` where ``sample`` is the loaded image
            (passed through ``transform`` when one was given) and ``target``
            is its float32 DMOS value.
        """
        path, target = self.samples[index]
        sample = self._load_image(path)
        # transform defaults to None, so it must be optional here.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, target

    def __len__(self):
        return len(self.samples)


class SPAQDATASET(data.Dataset):
    """Dataset reading ``Annotations/MOS and Image attribute scores.xlsx``.

    For every position in ``index`` (a row of the spreadsheet), ``patch_num``
    copies of ``(root/SPAQ_zip/TestImage/<Image name>, MOS)`` are added.

    Args:
        root: dataset root directory.
        index: iterable of spreadsheet row positions to include.
        patch_num: number of times each selected image is repeated.
        transform: optional callable applied to the loaded PIL image.
    """

    def __init__(self, root, index, patch_num, transform=None):
        super(SPAQDATASET, self).__init__()

        self.data_path = root
        anno_folder = os.path.join(self.data_path, "Annotations")
        xlsx_file = os.path.join(anno_folder, "MOS and Image attribute scores.xlsx")
        read = pd.read_excel(xlsx_file)
        imgname = read["Image name"].values.tolist()
        # 0-d float32 arrays, identical to the historic representation.
        mos_all = [np.array(m).astype(np.float32) for m in read["MOS"].values.tolist()]

        image_dir = os.path.join(self.data_path, "SPAQ_zip", "TestImage")
        self.samples = [
            (os.path.join(image_dir, imgname[item]), mos_all[item])
            for item in index
            for _ in range(patch_num)
        ]
        self.transform = transform

    def _load_image(self, path):
        """Load ``path`` as an RGB PIL image.

        If loading fails, a 224x224 random-noise image is returned instead of
        raising.
        """
        try:
            im = Image.open(path).convert("RGB")
        except Exception:  # not bare: let KeyboardInterrupt/SystemExit propagate
            print("ERROR IMG LOADED: ", path)
            random_img = np.random.rand(224, 224, 3) * 255
            im = Image.fromarray(np.uint8(random_img))
        return im

    def __getitem__(self, index):
        """Return the ``index``-th ``(sample, target)`` pair.

        Args:
            index (int): Position in ``self.samples``.

        Returns:
            tuple: ``(sample, target)`` where ``sample`` is the loaded image
            (passed through ``transform`` when one was given) and ``target``
            is its float32 MOS.
        """
        path, target = self.samples[index]
        sample = self._load_image(path)
        # transform defaults to None, so it must be optional here.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, target

    def __len__(self):
        return len(self.samples)


class FBLIVEFolder(data.Dataset):
    """Dataset reading ``labels_image.csv`` under ``root``.

    For every position in ``index`` (a row of the CSV), ``patch_num`` copies
    of ``(root/database/<name>, mos)`` are added. Images smaller than
    ``target_size`` plus ``safety_margin`` are upscaled on load.

    Args:
        root: dataset root directory.
        index: iterable of CSV row positions to include.
        patch_num: number of times each selected image is repeated.
        transform: optional callable applied to the loaded PIL image.
    """

    def __init__(self, root, index, patch_num, transform=None):
        imgname = []
        mos_all = []
        csv_file = os.path.join(root, "labels_image.csv")
        # newline="" is what the csv module expects for files it parses.
        with open(csv_file, newline="", encoding="utf-8") as f:
            for row in csv.DictReader(f):
                imgname.append(row["name"])
                # 0-d float32 array, identical to the historic representation.
                mos_all.append(np.array(float(row["mos"]), dtype=np.float32))

        self.samples = [
            (os.path.join(root, "database", imgname[item]), mos_all[item])
            for item in index
            for _ in range(patch_num)
        ]
        self.transform = transform
        # (width, height) every loaded image must reach before transforms.
        self.target_size = (224, 224)
        # Extra pixels on top of target_size to absorb rounding in the resize.
        self.safety_margin = 2

    def _load_image(self, path):
        """Load ``path`` as RGB, upscaling it if it is too small.

        If either side is below ``target_size + safety_margin``, the image is
        resized (bicubic) by a single ratio so that both sides reach that
        minimum. If loading fails, a random-noise image of ``target_size`` is
        returned instead of raising.
        """
        try:
            im = Image.open(path).convert("RGB")

            min_width = self.target_size[0] + self.safety_margin
            min_height = self.target_size[1] + self.safety_margin

            if im.width < min_width or im.height < min_height:
                ratio = max(min_width / im.width, min_height / im.height)
                # int() truncates, so clamp back up to the minimum afterwards.
                new_width = max(int(im.width * ratio), min_width)
                new_height = max(int(im.height * ratio), min_height)
                im = im.resize((new_width, new_height), Image.BICUBIC)

        except Exception as e:
            print(f"ERROR IMG LOADED: {path}, Error: {e}")
            random_img = np.random.rand(self.target_size[1], self.target_size[0], 3) * 255
            im = Image.fromarray(np.uint8(random_img))

        return im

    def __getitem__(self, index):
        """Return the ``index``-th ``(sample, target)`` pair.

        Args:
            index (int): Position in ``self.samples``.

        Returns:
            tuple: ``(sample, target)`` where ``sample`` is the loaded image
            (passed through ``transform`` when one was given) and ``target``
            is its float32 MOS.
        """
        path, target = self.samples[index]
        sample = self._load_image(path)

        # Defensive re-check before transforms that crop to target_size.
        if sample.width < self.target_size[0] or sample.height < self.target_size[1]:
            print(f"WARNING: Image still too small after resize: {sample.width}x{sample.height}, path: {path}")
            sample = sample.resize((max(self.target_size[0], sample.width),
                                    max(self.target_size[1], sample.height)),
                                   Image.BICUBIC)

        # transform defaults to None, so it must be optional here.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, target

    def __len__(self):
        return len(self.samples)

# class FBLIVEFolder(data.Dataset):
#     def __init__(self, root, index, patch_num, transform=None):
#         imgname = []
#         mos_all = []
#         csv_file = os.path.join(root, "labels_image.csv")
#         with open(csv_file) as f:
#             reader = csv.DictReader(f)
#             for row in reader:
#                 imgname.append(row["name"])
#                 mos = np.array(float(row["mos"])).astype(np.float32)
#                 mos_all.append(mos)
#
#         sample = []
#         for i, item in enumerate(index):
#             for aug in range(patch_num):
#                 sample.append(
#                     (os.path.join(root, "database", imgname[item]), mos_all[item])
#                 )
#
#         self.samples = sample
#         self.transform = transform
#
#     def _load_image(self, path):
#         try:
#             im = Image.open(path).convert("RGB")
#         except:
#             print("ERROR IMG LOADED: ", path)
#             random_img = np.random.rand(224, 224, 3) * 255
#             im = Image.fromarray(np.uint8(random_img))
#         return im
#
#     def __getitem__(self, index):
#         """
#         Args:
#             index (int): Index
#         Returns:
#             tuple: (sample, target) where target is class_index of the target class.
#         """
#         path, target = self.samples[index]
#         sample = self._load_image(path)
#         sample = self.transform(sample)
#         return sample, target
#
#     def __len__(self):
#         length = len(self.samples)
#         return length
